{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d27ba5b9",
   "metadata": {},
   "outputs": [],
   "source": [
    "# type: ignore"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2e80407c",
   "metadata": {},
   "source": [
    "# Fireworks Supervised Fine-Tuning\n",
    "\n",
    "This recipe allows TensorZero users to fine-tune open-source LLMs using their own data.\n",
    "Since TensorZero automatically logs all inferences and feedback, it is straightforward to fine-tune a model using your own data and any prompt you want.\n",
    "We follow the Fireworks [docs](https://docs.fireworks.ai/fine-tuning/fine-tuning-via-api) on fine-tuning a model.\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "40542c99",
   "metadata": {},
   "source": [
    "To get started:\n",
    "\n",
    "- Set the `TENSORZERO_CLICKHOUSE_URL`, `FIREWORKS_API_KEY`, and `FIREWORKS_ACCOUNT_ID` environment variables. See the `.env.example` file.\n",
    "- Update the following parameters:\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "56877706",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import sys\n",
    "\n",
    "from dotenv import load_dotenv\n",
    "\n",
    "load_dotenv()\n",
    "\n",
    "# Read the required settings from the environment (populated by `load_dotenv()` above when a `.env` file exists).\n",
    "CLICKHOUSE_URL = os.getenv(\"TENSORZERO_CLICKHOUSE_URL\")\n",
    "# NOTE(review): FIREWORKS_API_KEY is only validated here; no cell passes it explicitly.\n",
    "# Presumably the gateway reads it from the environment -- confirm.\n",
    "FIREWORKS_API_KEY = os.getenv(\"FIREWORKS_API_KEY\")\n",
    "account_id = os.getenv(\"FIREWORKS_ACCOUNT_ID\")\n",
    "\n",
    "# Truthiness checks (rather than `is not None`) also reject variables that are set but empty,\n",
    "# e.g. a `KEY=` line left blank in `.env`, so the failure happens here instead of deep inside a later call.\n",
    "assert CLICKHOUSE_URL, \"TENSORZERO_CLICKHOUSE_URL is not set\"\n",
    "assert FIREWORKS_API_KEY, \"FIREWORKS_API_KEY is not set\"\n",
    "assert account_id, \"FIREWORKS_ACCOUNT_ID is not set\"\n",
    "\n",
    "# Put the directory three levels above the working directory on `sys.path`;\n",
    "# the later `from recipes.util import ...` presumably relies on it.\n",
    "tensorzero_path = os.path.abspath(os.path.join(os.getcwd(), \"../../../\"))\n",
    "if tensorzero_path not in sys.path:\n",
    "    sys.path.append(tensorzero_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5d83fa9f",
   "metadata": {},
   "outputs": [],
   "source": [
    "CONFIG_PATH = \"../../../examples/data-extraction-ner/config/tensorzero.toml\"\n",
    "\n",
    "FUNCTION_NAME = \"extract_entities\"\n",
    "\n",
    "METRIC_NAME = \"jaccard_similarity\"\n",
    "\n",
    "# The name of the variant to use to grab the templates used for fine-tuning\n",
    "TEMPLATE_VARIANT_NAME = \"gpt_4o_mini\"  # It's OK that this variant uses a different model than the one we're fine-tuning\n",
    "\n",
    "# If the metric is a float metric, you can set the threshold to filter the data\n",
    "# (the query below passes it to a filter with the \">\" comparison operator)\n",
    "FLOAT_METRIC_THRESHOLD = 0.5\n",
    "\n",
    "# Fraction of the data to use for validation\n",
    "VAL_FRACTION = 0.2\n",
    "\n",
    "# Number of epochs to train for\n",
    "# NOTE(review): no cell in this notebook reads NUM_EPOCHS (it is not passed to FireworksSFTConfig),\n",
    "# so changing it currently has no effect -- confirm whether it should be forwarded to the job config.\n",
    "NUM_EPOCHS = 1\n",
    "\n",
    "# Maximum number of samples to use for fine-tuning (for Fireworks, NUM_EPOCHS * MAX_SAMPLES should be <= 3,000,000)\n",
    "MAX_SAMPLES = 100_000\n",
    "\n",
    "# The name of the model to fine-tune (supported models: https://docs.fireworks.ai/fine-tuning/fine-tuning-models#supported-base-models)\n",
    "MODEL_NAME = \"accounts/fireworks/models/llama-v3p1-8b-instruct\"\n",
    "\n",
    "# At the time of writing, Fireworks does not support tool call content blocks in assistant messages, nor the tool role.\n",
    "# This flag is meant to control whether such invalid messages are dropped from the dataset (True) or kept (False).\n",
    "# NOTE(review): no cell in this notebook reads DROP_INVALID_MESSAGES, so nothing is dropped based on it here --\n",
    "# confirm whether the filtering happens elsewhere or the flag is stale.\n",
    "DROP_INVALID_MESSAGES = True"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9dac4e18",
   "metadata": {},
   "outputs": [],
   "source": [
    "from time import sleep\n",
    "\n",
    "import toml\n",
    "from IPython.display import clear_output\n",
    "from tensorzero import (\n",
    "    FireworksSFTConfig,\n",
    "    FloatMetricFilter,\n",
    "    OptimizationJobStatus,\n",
    "    TensorZeroGateway,\n",
    ")\n",
    "\n",
    "from recipes.util import train_val_split"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "06cfd900",
   "metadata": {},
   "source": [
    "Initialize the embedded TensorZero client\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a07d4ed9",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build the embedded gateway from the TensorZero config file and the ClickHouse URL read earlier.\n",
    "gateway_kwargs = {\"config_file\": CONFIG_PATH, \"clickhouse_url\": CLICKHOUSE_URL}\n",
    "t0 = TensorZeroGateway.build_embedded(**gateway_kwargs)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "62933c5a",
   "metadata": {},
   "source": [
    "Query for stored examples\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2441ef96",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Select inferences by comparing their METRIC_NAME value against the threshold with \">\".\n",
    "# For a boolean metric, swap in the following instead:\n",
    "#   from tensorzero import BooleanMetricFilter\n",
    "#   filters = BooleanMetricFilter(metric_name=METRIC_NAME, value=True)\n",
    "filters = FloatMetricFilter(\n",
    "    metric_name=METRIC_NAME,\n",
    "    value=FLOAT_METRIC_THRESHOLD,\n",
    "    comparison_operator=\">\",\n",
    ")\n",
    "\n",
    "# Fetch at most MAX_SAMPLES matching inferences for FUNCTION_NAME.\n",
    "# You could also train on demonstrations by changing the output_source to \"demonstration\"\n",
    "stored_samples = t0.experimental_list_inferences(\n",
    "    function_name=FUNCTION_NAME, filters=filters, output_source=\"inference\", limit=MAX_SAMPLES\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "88fe6727",
   "metadata": {},
   "source": [
    "Template the data using the variant we chose above.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8ab26701",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Map the function to the variant whose templates should be used for rendering.\n",
    "variant_for_function = {FUNCTION_NAME: TEMPLATE_VARIANT_NAME}\n",
    "\n",
    "rendered_samples = t0.experimental_render_samples(\n",
    "    stored_samples=stored_samples,\n",
    "    variants=variant_for_function,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d150678e",
   "metadata": {},
   "source": [
    "Split the data into training and validation sets for fine-tuning."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6cb0fea2",
   "metadata": {},
   "outputs": [],
   "source": [
    "# VAL_FRACTION of the rendered samples is set aside for validation; the rest is used for training.\n",
    "split_options = {\"val_size\": VAL_FRACTION, \"last_inference_only\": True}\n",
    "train_samples, val_samples = train_val_split(rendered_samples, **split_options)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "13912945",
   "metadata": {},
   "source": [
    "Launch the fine-tuning job."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3a4b4051",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Job description: the base model to fine-tune and the Fireworks account to run it under.\n",
    "optimization_config = FireworksSFTConfig(model=MODEL_NAME, account_id=account_id)\n",
    "\n",
    "# Submit the job. The returned handle is what the polling cell below passes to `experimental_poll_optimization`.\n",
    "job_handle = t0.experimental_launch_optimization(\n",
    "    train_samples=train_samples, val_samples=val_samples, optimization_config=optimization_config\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "12fba35b",
   "metadata": {},
   "source": [
    "Wait for the fine-tuning job to complete.\n",
    "\n",
    "This cell will take a while to run."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cc5d0ac7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Poll until the job reaches a terminal status (Completed or Failed).\n",
    "# Transient polling errors are tolerated, but a persistent failure is re-raised after\n",
    "# MAX_CONSECUTIVE_POLL_ERRORS attempts in a row instead of looping forever.\n",
    "MAX_CONSECUTIVE_POLL_ERRORS = 5\n",
    "POLL_INTERVAL_SECONDS = 10\n",
    "\n",
    "consecutive_errors = 0\n",
    "while True:\n",
    "    clear_output(wait=True)\n",
    "\n",
    "    try:\n",
    "        job_info = t0.experimental_poll_optimization(job_handle=job_handle)\n",
    "    except Exception as e:\n",
    "        consecutive_errors += 1\n",
    "        print(f\"Error: {e}\")\n",
    "        if consecutive_errors >= MAX_CONSECUTIVE_POLL_ERRORS:\n",
    "            # Give up and surface the last error to the user.\n",
    "            raise\n",
    "    else:\n",
    "        # A successful poll resets the failure streak.\n",
    "        consecutive_errors = 0\n",
    "        print(job_info)\n",
    "        if job_info.status in (\n",
    "            OptimizationJobStatus.Completed,\n",
    "            OptimizationJobStatus.Failed,\n",
    "        ):\n",
    "            break\n",
    "\n",
    "    sleep(POLL_INTERVAL_SECONDS)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "115b1ff6",
   "metadata": {},
   "source": [
    "Once the fine-tuning job is complete, you can add the fine-tuned model to your config file."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "60f7d29b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# The polling loop also exits when the job has Failed. Stop here with a clear message in that case\n",
    "# rather than failing on the `output` lookup below (presumably only a completed job carries an output).\n",
    "if job_info.status != OptimizationJobStatus.Completed:\n",
    "    raise RuntimeError(f\"Fine-tuning job did not complete successfully: {job_info}\")\n",
    "\n",
    "# The first routing entry of the job output is used both as the model's name in the config\n",
    "# and as the Fireworks provider's `model_name`.\n",
    "fine_tuned_model = job_info.output[\"routing\"][0]\n",
    "model_config = {\n",
    "    \"models\": {\n",
    "        fine_tuned_model: {\n",
    "            \"routing\": [\"fireworks\"],\n",
    "            \"providers\": {\"fireworks\": {\"type\": \"fireworks\", \"model_name\": fine_tuned_model}},\n",
    "        }\n",
    "    }\n",
    "}\n",
    "\n",
    "# Print the TOML snippet to paste into the config file.\n",
    "print(toml.dumps(model_config))"
   ]
  },
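  {
   "cell_type": "markdown",
   "id": "f7f3c3ec",
   "metadata": {},
   "source": [
    "Finally, add a new variant to your function to use the fine-tuned model.\n",
    "In your config file, create a variant for your function (`FUNCTION_NAME` above) whose `model` is the fine-tuned model name printed in the previous cell.\n"
   ]
  },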
  {
   "cell_type": "markdown",
   "id": "f215b20b",
   "metadata": {},
   "source": [
    "You're all set!\n",
    "\n",
    "You can change the weight to enable a gradual rollout of the new model.\n",
    "\n",
    "You might also add other parameters (e.g. `max_tokens`, `temperature`) to the variant section in the config file.\n"
   ]
  }
 ],
 "metadata": {
  "jupytext": {
   "cell_metadata_filter": "-all",
   "formats": "ipynb,py:percent",
   "main_language": "python"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
